load(":aot_triton_kernel.bzl", "aot_triton_kernel_library")
load("//:def.bzl", "cuda_copts")

# C++/CUDA library wrapping the layernorm kernels.
#
# Compiles layernorm_kernels.cu with the flags returned by cuda_copts() and
# publishes layernorm_kernels.h under the "src" include prefix. The target is
# visible to every package in the workspace.
cc_library(
    name = "triton_kernel",
    srcs = ["layernorm_kernels.cu"],
    hdrs = ["layernorm_kernels.h"],
    copts = cuda_copts(),
    include_prefix = "src",
    visibility = ["//visibility:public"],
    deps = [
        # NOTE(review): presumably produced by the "layernorm_kernel"
        # aot_triton_kernel_library target in this package -- confirm the
        # "_lib" suffix against aot_triton_kernel.bzl.
        ":layernorm_kernel_lib",
        "//rtp_llm/cpp/cuda:cuda_host_utils",
        "//rtp_llm/cpp/utils:core_utils",
        "//rtp_llm/cpp/model_utils:model_utils",
    ],
)

# Ahead-of-time build of the Triton kernel `_layer_norm_fwd_1pass_kernel`
# taken from layernorm_kernels.py.
#
# The `{ty}`, `{bias}` and `{constant}` placeholders used in `spec` (and `{ty}`
# in `output_name_tpl`) are named after the keys of `var_map`.
# NOTE(review): presumably the macro substitutes every combination of the
# `var_map` values and groups the generated kernels by `groupby` -- confirm
# against aot_triton_kernel.bzl.
aot_triton_kernel_library(
    name = "layernorm_kernel",
    triton_script = ":layernorm_kernels.py",
    kernel_name = "_layer_norm_fwd_1pass_kernel",
    output_name_tpl = "layernorm_kernel_{ty}",
    # Kernel signature template; one entry per kernel argument.
    spec = "*{ty}:16, *{ty}:16, *{ty}:16, *{ty}:16, *{ty}:16, *{ty}:16, *{ty}:16, i32, i32, i32, i32, i32:16, fp32, i32, {bias}, {bias}, {constant}",
    # Candidate values for each placeholder in the templates above.
    var_map = {
        "ty": ["fp32","fp16","bf16"],
        "constant": ["1024", "2048", "4096"],
        "bias": ["0", "1"],
    },
    groupby = "ty",
    grid = "M,1,1",
    num_warps = [4, 8],
)
